# -*- coding: utf-8 -*-
"""
Created on Mon Apr 10 23:16:30 2023

@author: 29672366
"""

import torch
import torch.nn as nn
import numpy as np
import math
import torch.nn.functional as F
import random


class TransformerQA(nn.Module):
    """Autoregressive Transformer decoder for sequence-to-sequence QA.

    There is no encoder stack: the token-embedded input sequence is used
    directly as the decoder "memory", and the output sequence is generated
    one step at a time, optionally with teacher forcing during training.
    """

    def __init__(self, vocab_size,
                 embedding_dim,
                 hidden_dim,
                 num_layers,
                 num_heads,
                 max_seq_len,
                 dropout=0.2,
                 stop_token_id=13):
        """
        Args:
            vocab_size: size of the token vocabulary.
            embedding_dim: token embedding width (must be divisible by num_heads).
            hidden_dim: feed-forward width inside each decoder layer.
            num_layers: number of stacked decoder layers.
            num_heads: attention heads per layer.
            max_seq_len: generation length cap when no target sequence is given.
            dropout: dropout rate inside the decoder layers.
            stop_token_id: vocabulary id that ends free-running generation
                (was hard-coded to 13; kept as the backward-compatible default).
        """
        super(TransformerQA, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Decoder stack; the embedded input sequence serves as its memory.
        decoder_layer = nn.TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.stop_token_id = stop_token_id
        nn.init.xavier_uniform_(self.fc.weight)

    def forward(self, input_seq, target_seq = None, teachforcing = 0.5, temperature = 0.7):
        """Generate temperature-scaled output logits.

        Args:
            input_seq: (batch_size, seq_len) token ids.
            target_seq: optional (batch_size, seq_len) token ids; when given,
                its length fixes the generated length and teacher forcing
                may be applied.
            teachforcing: per-step probability of feeding the ground-truth
                embedding (only meaningful when target_seq is given).
            temperature: logits are divided by this before being returned.

        Returns:
            (batch_size, out_len, vocab_size) logits divided by temperature.
        """
        return self.generate_output_sequence(input_seq, target_seq, teachforcing, temperature)

    def encode(self, seq):
        """Embed token ids and move the time axis first.

        Returns a (seq_len, batch_size, embedding_dim) tensor, the layout
        nn.TransformerDecoder expects with batch_first=False.
        """
        return self.embedding(seq).transpose(0, 1)

    def generate_output_sequence(self, input_seq, target_seq = None, teachforcing = 0.5, temperature = 0.7):
        """Step-by-step decoding loop; see forward() for the contract."""
        # input_seq: (batch_size, seq_len); target_seq: (batch_size, seq_len)
        encoded_input = self.encode(input_seq)  # decoder memory, (seq_len, batch, emb)
        encoded_target = None

        # A constant -1 vector plays the role of the start-of-sequence embedding.
        start_token = torch.full(
            (1, encoded_input.size(1), encoded_input.size(2)), -1.
        ).to(encoded_input.device)

        if target_seq is not None:
            # The target length fixes the number of generated steps.
            encoded_target = self.encode(target_seq)
            target_len = target_seq.size(1)
        else:
            # Free-running mode: generate up to max_seq_len steps.
            target_len = self.max_seq_len

        # First step: decode the start token at position 0. Subsequent steps
        # use positions i + 1, so the start token is never stripped later.
        output_seq = self.decode_with_pos_enc(start_token, encoded_input, target_len, 0)
        for i in range(target_len - 1):
            next_token = None
            if encoded_target is not None and 0 < teachforcing <= 1:
                # With probability `teachforcing`, feed the ground-truth
                # embedding instead of the model's own last output.
                if np.random.rand() < teachforcing:
                    next_token = encoded_target[i, :, :].unsqueeze(0)

            if next_token is None:  # fix: was `== None`, fragile on tensors
                next_token = output_seq[-1:, :, :]

            # Decode one more step at position i + 1 and append it.
            output_i = self.decode_with_pos_enc(next_token, encoded_input, target_len, i + 1)
            output_seq = torch.cat([output_seq, output_i], dim=0)

            # Free-running mode: stop once the stop token is produced.
            # softmax was dropped before argmax (it does not change the
            # argmax). NOTE: a global argmax is only valid for batch size 1.
            if encoded_target is None and torch.argmax(self.fc(output_i)) == self.stop_token_id:
                break

        output_seq = output_seq.transpose(0, 1)  # (batch, seq_len, emb)
        output_seq = self.fc(output_seq)         # (batch, seq_len, vocab_size)
        return output_seq / temperature

    def decode_with_pos_enc(self, last_output, encoded_input, target_len, indices):
        """Run one decoder step with positional encoding added.

        Args:
            last_output: (1, batch, emb) embedding fed to the decoder.
            encoded_input: (seq_len, batch, emb) decoder memory.
            target_len: full sequence length (size of the positional table).
            indices: integer position of this step in the output sequence.

        Returns:
            (1, batch, emb) decoder output for this step.
        """
        pos_enc = self.position_encoding(target_len, last_output.size(-1)).to(encoded_input.device)
        # pos_enc[:, indices, :] is (1, emb) and broadcasts over the batch.
        output_with_pos = last_output + pos_enc[:, indices, :]
        return self.decoder(output_with_pos, encoded_input)

    def position_encoding(self, seq_len, hidden_size):
        """Sinusoidal positional encoding of shape (1, seq_len, hidden_size).

        Fixed to support odd hidden_size: the cosine half has
        hidden_size // 2 columns, one fewer than the sine half, so the
        angle table is truncated accordingly (the original raised a
        shape-mismatch error for odd widths).
        """
        pos = torch.arange(seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2) * (-math.log(10000.0) / hidden_size))
        angles = pos * div_term.unsqueeze(0)  # (seq_len, ceil(hidden_size / 2))
        pos_enc = torch.zeros((1, seq_len, hidden_size))
        pos_enc[:, :, 0::2] = torch.sin(angles)
        pos_enc[:, :, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        return pos_enc